import pandas as pd
from datasets import Dataset, Features, Value
from nlpx.text_token.utils import get_df_text_labels


if __name__ == '__main__':
	# Build a train/test DatasetDict from the incident-statistics CSV and save it to ./data.
	file = '~/project/python/parttime/text_gcn/data/北方地区不安全事件统计20240331.csv'
	# The source CSV is GBK-encoded; pandas expands the leading '~' in string paths itself.
	df = pd.read_csv(file, encoding='GBK')
	# Text column: '故障描述' (fault description); label column: '故障标志' (fault flag).
	texts, labels, classes = get_df_text_labels(df, text_col='故障描述', label_col='故障标志')

	# Single seed used for every random step so the saved split is reproducible.
	seed = 42

	features = Features({
		'text': Value('string'),  # raw text feature
		# NOTE(review): int8 can hold at most 128 non-negative label ids; this assumes
		# get_df_text_labels returns small integer ids. Confirm, and widen if len(classes) grows.
		'label': Value('int8')
	})

	dataset = Dataset.from_dict({
		'text': texts,
		'label': labels
	}, features=features)

	# train_test_split shuffles again by default (shuffle=True). Without a seed that second
	# shuffle is random, which would make the split differ on every run despite the seeded
	# shuffle before it. Pass the same seed to both.
	dataset_dict = dataset.shuffle(seed=seed).train_test_split(test_size=0.2, seed=seed)
	# Output directory is relative to the current working directory.
	dataset_dict.save_to_disk('data')
